# *****************************************************************************
#  Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#      * Redistributions of source code must retain the above copyright
#        notice, this list of conditions and the following disclaimer.
#      * Redistributions in binary form must reproduce the above copyright
#        notice, this list of conditions and the following disclaimer in the
#        documentation and/or other materials provided with the distribution.
#      * Neither the name of the NVIDIA CORPORATION nor the
#        names of its contributors may be used to endorse or promote products
#        derived from this software without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
#  ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
#  WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
#  DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
#  DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
#  (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
#  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
#  ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
#  (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
#  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************

import argparse
import os
import sys
import time

sys.path.append('waveglow/')

import numpy as np
import torch
from scipy.io.wavfile import write

from dllogger.autologging import log_hardware, log_args
from dllogger.logger import LOGGER
import dllogger.logger as dllg

import models
from tacotron2.text import text_to_sequence


def parse_args(parser):
    """
    Register all inference command-line options on *parser*.

    :param parser: argparse.ArgumentParser to populate
    :return: the same parser with all options added
    """
    parser.add_argument('-i', '--input', type=str, default="",
                        help='full path to the input text (phrases separated '
                             'by new line); if not provided then use default '
                             'text')
    parser.add_argument('-o', '--output', required=True,
                        help='output folder to save audio (file per phrase)')
    parser.add_argument('--tacotron2', type=str, default="",
                        help='full path to the Tacotron2 model checkpoint file')
    parser.add_argument('--mel-file', type=str, default="",
                        help='set if using mel spectrograms instead of Tacotron2 model')
    parser.add_argument('--waveglow', required=True,
                        help='full path to the WaveGlow model checkpoint file')
    parser.add_argument('--old-waveglow', action='store_true',
                        help='set if WaveGlow checkpoint is from GitHub.com/NVIDIA/waveglow')
    parser.add_argument('-s', '--sigma-infer', default=0.6, type=float,
                        help='sigma value passed to WaveGlow inference')
    parser.add_argument('-sr', '--sampling-rate', default=22050, type=int,
                        help='Sampling rate')
    parser.add_argument('--fp16-run', action='store_true',
                        help='inference in fp16')
    parser.add_argument('--log-file', type=str, default='nvlog.json',
                        help='Filename for logging')

    return parser


def load_checkpoint(checkpoint_path, model_name, model=None):
    """
    Restore a model's weights from a checkpoint file.

    :param checkpoint_path: path to a checkpoint saved with torch.save,
        containing a 'state_dict' entry
    :param model_name: human-readable name, used only in log messages
    :param model: model instance to load the weights into. The original
        code referenced an undefined global ``model``; it is now an
        explicit argument (default None kept for signature compatibility)
    :return: the same model with checkpoint weights loaded
    :raises ValueError: if no model instance is supplied
    :raises FileNotFoundError: if checkpoint_path does not exist
    """
    if model is None:
        raise ValueError("load_checkpoint requires a model instance to load into")
    # Raise instead of assert: asserts are stripped under `python -O`.
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(checkpoint_path)

    print("Loading checkpoint '{}'".format(checkpoint_path))
    # map_location='cpu' lets checkpoints saved on GPU load on any machine.
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint_dict['state_dict'])
    print("Loaded '{}' checkpoint '{}'" .format(model_name, checkpoint_path))
    return model


def checkpoint_from_distributed(state_dict):
    """
    Checks whether checkpoint was generated by DistributedDataParallel. DDP
    wraps model in additional "module.", it needs to be unwrapped for single
    GPU inference.
    :param state_dict: model's state dict
    """
    # Any key carrying the DDP "module." wrapper marks the whole checkpoint.
    return any('module.' in key for key in state_dict)


def unwrap_distributed(state_dict):
    """
    Unwraps model from DistributedDataParallel.
    DDP wraps model in additional "module.", it needs to be removed for single
    GPU inference.
    :param state_dict: model's state dict
    """
    # Strip every "module." occurrence from each key, keeping values as-is.
    return {key.replace('module.', ''): value
            for key, value in state_dict.items()}


def load_and_setup_model(model_name, parser, checkpoint, fp16_run):
    """
    Build the named model on GPU in eval mode, optionally restoring weights.

    :param model_name: model identifier understood by the `models` module
    :param parser: base argument parser extended with model-specific args
    :param checkpoint: checkpoint path, or None to keep random weights
    :param fp16_run: if True, build the model in fp16
    :return: the model, on CUDA, in eval mode
    """
    sub_parser = models.parse_model_args(model_name, parser, add_help=False)
    cli_args, _ = sub_parser.parse_known_args()
    config = models.get_model_config(model_name, cli_args)

    model = models.get_model(model_name, config, to_fp16=fp16_run,
                             to_cuda=True, training=False)

    if checkpoint is not None:
        state_dict = torch.load(checkpoint)['state_dict']
        # DDP-saved checkpoints need their "module." prefixes stripped.
        if checkpoint_from_distributed(state_dict):
            state_dict = unwrap_distributed(state_dict)
        model.load_state_dict(state_dict)

    model.eval()
    return model


def main():
    """
    Launches text to speech (inference).
    Inference is executed on a single GPU.

    Requires either a Tacotron2 checkpoint (--tacotron2) to synthesize mel
    spectrograms from text, or a precomputed mel file (--mel-file). WaveGlow
    (--waveglow) turns the mel spectrograms into audio, one .wav per phrase.
    """
    parser = argparse.ArgumentParser(
        description='PyTorch Tacotron 2 Inference')
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    # Fail fast when there is no source of mel spectrograms: previously an
    # unset `mel` variable caused a NameError at waveglow.infer().
    if not args.tacotron2 and not args.mel_file:
        parser.error('one of --tacotron2 or --mel-file must be provided')

    LOGGER.set_model_name("Tacotron2_PyT")
    LOGGER.set_backends([
        dllg.StdOutBackend(log_file=None,
                           logging_scope=dllg.TRAIN_ITER_SCOPE, iteration_interval=1),
        dllg.JsonBackend(log_file=args.log_file,
                         logging_scope=dllg.TRAIN_ITER_SCOPE, iteration_interval=1)
    ])
    LOGGER.register_metric("tacotron2_items_per_sec", metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("waveglow_items_per_sec", metric_scope=dllg.TRAIN_ITER_SCOPE)

    log_hardware()
    log_args(args)

    # tacotron2 model filepath was specified
    if args.tacotron2:
        # Setup Tacotron2
        tacotron2 = load_and_setup_model('Tacotron2', parser, args.tacotron2,
                                         args.fp16_run)
    # file with mel spectrogram was specified
    else:
        # One precomputed mel spectrogram is reused for every phrase.
        mel = torch.load(args.mel_file)
        mel = torch.unsqueeze(mel.cuda(), 0)

    # Setup WaveGlow
    if args.old_waveglow:
        # Checkpoints from github.com/NVIDIA/waveglow pickle the whole model
        # under the 'model' key instead of a bare state dict.
        waveglow = torch.load(args.waveglow)['model']
        waveglow = waveglow.remove_weightnorm(waveglow)
        waveglow = waveglow.cuda()
        waveglow.eval()
    else:
        waveglow = load_and_setup_model('WaveGlow', parser, args.waveglow,
                                        args.fp16_run)

    try:
        # `with` guarantees the file handle is closed (old code leaked it),
        # and OSError replaces the bare except that swallowed everything.
        with open(args.input, 'r') as f:
            texts = f.readlines()
    except OSError:
        print("Could not read file. Using default text.")
        texts = ["The forms of printed letters should be beautiful, and\
        that their arrangement on the page should be reasonable and\
        a help to the shapeliness of the letters themselves."]

    for i, text in enumerate(texts):

        LOGGER.iteration_start()

        # Phoneme/character ids with a leading batch dimension of 1.
        sequence = np.array(text_to_sequence(text, ['english_cleaners']))[None, :]
        # torch.autograd.Variable is a no-op since torch 0.4 (the file
        # already relies on torch.no_grad), so plain tensors are used.
        sequence = torch.from_numpy(sequence).cuda().long()

        if args.tacotron2:
            tacotron2_t0 = time.time()
            with torch.no_grad():
                # infer() returns a 4-tuple; only the mel spectrogram is
                # needed downstream -- TODO confirm layout in tacotron2 code.
                _, mel, _, _ = tacotron2.infer(sequence)
            tacotron2_t1 = time.time()
            tacotron2_infer_perf = sequence.size(1)/(tacotron2_t1-tacotron2_t0)
            LOGGER.log(key="tacotron2_items_per_sec", value=tacotron2_infer_perf)

        waveglow_t0 = time.time()
        with torch.no_grad():
            audio = waveglow.infer(mel, sigma=args.sigma_infer)
            # Cast back to fp32 so scipy can write the samples in fp16 runs.
            audio = audio.float()
        waveglow_t1 = time.time()
        waveglow_infer_perf = audio[0].size(0)/(waveglow_t1-waveglow_t0)

        # os.path.join works whether or not --output ends with a separator;
        # bare concatenation produced paths like "outaudio_0.wav".
        audio_path = os.path.join(args.output, "audio_" + str(i) + ".wav")
        write(audio_path, args.sampling_rate, audio[0].data.cpu().numpy())

        LOGGER.log(key="waveglow_items_per_sec", value=waveglow_infer_perf)
        LOGGER.iteration_stop()

    LOGGER.finish()

# Script entry point; importing this module has no side effects beyond
# the definitions above.
if __name__ == '__main__':
    main()
